import os
import pickle
import numpy as np
import random
import torch
import wandb
from torch.utils.data.dataset import Subset
from torch.utils.data import DataLoader
from copy import deepcopy
# Specify CUDA_VISIBLE_DEVICES in the command, 
# e.g., CUDA_VISIBLE_DEVICES=0,1 nohup bash exp_on_b7server_0.sh
# ---------------------- only for debug -----------------------
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# -------------------------------------------------------------

from src.utils_others import *
from src.utils_data import combine_two_batch
from src.utils_data_ic import IC_Generate_Batch, Continual_IC_Dataset, ImageDataset
# from src.utils_data_tc import TC_collate_fn, Continual_TC_Dataset
# from src.utils_data_ner import NER_collate_fn, Continual_NER_Dataset
from src.trainer import BaseTrainer
from src.config import get_params
from src.utils_eval import *


def _mean_or_zero(values):
    """Return ``np.mean(values)``, or 0 for an empty list.

    ``np.mean([])`` returns NaN and emits a RuntimeWarning, which would end up in the logs.
    """
    return np.mean(values) if len(values) > 0 else 0


def _append_replay_idx(idx, buffer_idx):
    """Extend a batch's sample indices with ``-1`` placeholders for mixed-in replay samples.

    The replayed samples are assumed to be appended after the new-task samples by
    ``combine_two_batch(X, buffer_X, y, buffer_y)`` (the list branch has always relied on
    this), so the trailing ``len(buffer_idx)`` positions become ``-1`` while the leading
    positions keep the original indices.

    Args:
        idx: indices of the new-task samples (list, tuple or 1-D tensor).
        buffer_idx: indices of the replayed samples; only its length is used.

    Returns:
        A list (for list/tuple ``idx``) or a tensor (for tensor ``idx``) of length
        ``len(idx) + len(buffer_idx)``.

    Raises:
        NotImplementedError: for any other type of ``idx``.
    """
    if isinstance(idx, (list, tuple)):
        return list(idx) + [-1] * len(buffer_idx)
    if isinstance(idx, torch.Tensor):
        # Mark exactly the appended tail; the new-task indices in front must stay intact
        # even when the two batches have different sizes.
        return torch.cat((idx, idx.new_full((len(buffer_idx),), -1)))
    raise NotImplementedError()


def _is_final_task_ckpt(file_name, last_task_id):
    """Return True if ``file_name`` is the per-task checkpoint ``..._task_id_<last_task_id>.pth``.

    The whole suffix is matched, so names such as ``..._task_id_11.pth`` or
    ``..._epoch_1.pth`` are not mistaken for the checkpoint of task 1.
    """
    return file_name.endswith('_task_id_%d.pth' % last_task_id)


def _init_wandb(params, max_retries=5, retry_wait_sec=120):
    """Start a wandb run (retrying on errors) and store its URL in ``params.wandb_url``.

    Args:
        params: experiment config; ``wandb_project``, ``wandb_entity`` and ``wandb_name``
            are read and ``vars(params)`` is sent as the run config.
        max_retries: number of ``wandb.init`` attempts.
        retry_wait_sec: seconds to wait between two consecutive attempts.

    Raises:
        RuntimeError: if every attempt failed.
    """
    import time

    for attempt in range(max_retries):
        try:
            wandb.init(project=params.wandb_project, entity=params.wandb_entity, name=params.wandb_name,
                       config=vars(params), settings=wandb.Settings(start_method="fork"))
            break
        except Exception as e:
            # retry request (handles connection errors, timeouts)
            print(str(e))
            if attempt < max_retries - 1:
                print("Retrying Connecting wandb...")
                time.sleep(retry_wait_sec)
    else:
        raise RuntimeError("Failed to initialize wandb after %d attempts" % max_retries)
    params.wandb_url = wandb.run.get_url()


def _get_train_transform(train_loader):
    """Return the ``transform`` of the dataset behind ``train_loader``.

    One wrapping level (an object exposing ``.dataset``, e.g. a ``Subset``) is unwrapped,
    so both per-task subset loaders and plain loaders are supported.
    """
    dataset = train_loader.dataset
    if hasattr(dataset, 'dataset'):
        dataset = dataset.dataset
    return dataset.transform


def _evaluate_all_seen(trainer, CL_dataset, task_id, split, **kwargs):
    """Dispatch to ``evaluate_all_seen_task_{ner,tc,ic}`` according to the task name.

    Raises:
        NotImplementedError: for task names other than 'NER', 'TC' and 'IC'.
    """
    task_name = trainer.params.task_name
    if task_name == 'NER':
        return evaluate_all_seen_task_ner(trainer, CL_dataset, task_id, split, **kwargs)
    if task_name == 'TC':
        return evaluate_all_seen_task_tc(trainer, CL_dataset, task_id, split, **kwargs)
    if task_name == 'IC':
        return evaluate_all_seen_task_ic(trainer, CL_dataset, task_id, split, **kwargs)
    raise NotImplementedError()


def _probe_and_compare(params, trainer, CL_dataset, task_id, logger, epoch, global_step, step):
    """Probe the encoder, then evaluate the current classifier, on the test split of all seen tasks.

    Both results are logged. The probing result is also sent to wandb at ``global_step``
    when ``params.is_wandb`` is set.

    Raises:
        NotImplementedError: for task names other than 'NER', 'TC' and 'IC'.
    """
    task_name = trainer.params.task_name
    if task_name not in ('NER', 'TC', 'IC'):
        raise NotImplementedError()
    # probe the ability of the encoder
    probe_result = probe_model(trainer, CL_dataset, task_id, 'test')
    logger.info("Epoch %d, Step %d, Probe result = %s"%(epoch, global_step, probe_result))
    if params.is_wandb:
        if task_name == 'NER':
            wandb.log({'PROBE_mif1': probe_result['Result_test_mean_mif1'],
                       'PROBE_maf1': probe_result['Result_test_mean_maf1']}, step=global_step)
        else:
            wandb.log({'PROBE_acc': probe_result['Result_test_mean_acc']}, step=global_step)
    # compare with the ability of the current classifier
    result_dict = _evaluate_all_seen(trainer, CL_dataset, task_id, 'test')
    logger.info("Epoch %d, Step %d, test result = %s" % (epoch, step, result_dict))


def _build_train_loader(params, trainer, CL_dataset, task_id):
    """Return the training loader for ``task_id``.

    * ``params.is_multitask``: the accumulated loader of all tasks up to ``task_id``.
    * ``params.is_combine_er`` and ``task_id > 0``: the task's training subset extended
      with every sample returned by ``buffer.get_buffer_all``, in a new shuffled loader.
    * otherwise: the task's own training loader.

    Raises:
        NotImplementedError: if combined ER is requested but the task loader's dataset is
            not a ``Subset``.
    """
    if params.is_multitask:
        return trainer.CL_dataset.get_accum_data_loader(task_id, 'train')
    if not (params.is_combine_er and task_id > 0):
        return CL_dataset.data_loader['train'][task_id]

    task_subset = CL_dataset.data_loader['train'][task_id].dataset
    if not isinstance(task_subset, Subset):
        raise NotImplementedError()
    # Work on deep copies so that appending buffer samples leaves the original loader untouched.
    select_idx_list = deepcopy(task_subset.indices)
    train_dataset_all = deepcopy(task_subset.dataset)
    all_buffer_data = trainer.model.buffer.get_buffer_all(task_id)
    cnt_old = len(train_dataset_all)
    buffer_y_list = [_y for _idx, _x, _y in all_buffer_data]

    if isinstance(train_dataset_all, ImageDataset):
        buffer_X_all = np.array([_x for _idx, _x, _y in all_buffer_data])
        train_dataset_all.data = np.concatenate((train_dataset_all.data, buffer_X_all))
        train_dataset_all.targets = np.concatenate((train_dataset_all.targets, np.array(buffer_y_list)))
    else:
        buffer_X_all = np.concatenate([np.expand_dims(np.array(_x), 0) for _idx, _x, _y in all_buffer_data])
        train_dataset_all.data = np.concatenate((train_dataset_all.data, buffer_X_all))
        train_dataset_all.targets.extend(buffer_y_list)
    cnt_new = len(train_dataset_all)
    # The buffer samples were appended at the end, so select their new positions as well.
    select_idx_list.extend(range(cnt_old, cnt_new))
    return DataLoader(dataset=Subset(train_dataset_all, select_idx_list),
                      batch_size=params.batch_size,
                      shuffle=True)


def _sample_replay_batch(params, trainer, train_loader, task_id):
    """Draw one batch from the replay buffer, built with the training transform, and move X/y to CUDA.

    Returns:
        ``(buffer_idx, buffer_X, buffer_y)`` with ``buffer_X`` and ``buffer_y`` on CUDA.

    Raises:
        NotImplementedError: for tasks other than 'IC' (the NER/TC collate functions are
            commented out in this script).
    """
    if params.task_name != 'IC':
        raise NotImplementedError()
    buffer_idx, buffer_X, buffer_y = IC_Generate_Batch(trainer.model.buffer.get_buffer_batch(task_id),
                                                       _get_train_transform(train_loader))
    return buffer_idx, buffer_X.cuda(), buffer_y.cuda()


def _evaluate_dev(params, trainer, CL_dataset, task_id, logger, epoch, step, global_step):
    """Evaluate the current task on the dev split, log it, and return the model-selection score.

    Returns:
        micro-F1 for 'NER', accuracy for 'TC' and 'IC'.

    Raises:
        NotImplementedError: for other task names.
    """
    task_name = trainer.params.task_name
    if task_name == 'NER':
        mif1, maf1, classf1 = evaluate_current_task_ner(trainer, CL_dataset, task_id, 'dev')
        logger.info("Current Task %d, Epoch %d, Step %d: Dev_micro_f1=%.3f, Dev_macro_f1=%.3f, Dev_f1_each_class=%s" % (
            task_id, epoch, step, mif1, maf1, classf1
        ))
        if params.is_wandb:
            wandb.log({'RESULT_dev_mif1_%d'%(task_id): mif1,
                       'RESULT_dev_maf1_%d'%(task_id): maf1}, step=global_step)
        return mif1

    if task_name == 'TC':
        acc, classacc = evaluate_current_task_tc(trainer, CL_dataset, task_id, 'dev')
    elif task_name == 'IC':
        acc, classacc = evaluate_current_task_ic(trainer, CL_dataset, task_id, 'dev')
    else:
        raise NotImplementedError()
    logger.info("Current Task %d, Epoch %d, Step %d: Dev_acc=%.3f, Dev_acc_each_class=%s" % (
        task_id, epoch, step, acc, classacc
    ))
    if params.is_wandb:
        wandb.log({'RESULT_dev_acc_%d'%(task_id): acc}, step=global_step)
    return acc


def main_cl(params):
    """Run a complete continual-learning experiment over all tasks of ``CL_dataset``.

    For every task:
    * except for the first task, the model is first evaluated on all seen tasks, including
      the upcoming one. This is the entry used later for forward transfer;
    * it is trained epoch by epoch (optionally with experience replay). The dev split is
      evaluated every ``evaluate_interval`` epochs. The best model is checkpointed, and
      training stops after ``early_stop`` evaluations without improvement;
    * the best checkpoint (or the last model if ``is_use_last_ckpt``) is tested on all
      tasks seen so far, filling one row of the ``NUM_TASK x NUM_TASK`` result summary.

    After the last task, forward/backward transfer and forgetting are computed from the
    summary. Optional tracking results are pickled, and the ``.pth`` files in
    ``params.dump_path`` are removed. If ``save_ckpt`` is set, the final task's best
    checkpoint is kept.

    Args:
        params: experiment configuration (as returned by ``get_params()``).

    Raises:
        NotImplementedError: for task names other than 'IC' (the NER/TC data pipelines
            are commented out in this script).
        RuntimeError: if ``is_wandb`` is set but wandb could not be initialized.
        FloatingPointError: if a training step produces a NaN/inf loss.
    """
    # ------------------------------------------------------------------------------------------------------------------------=====
    # Using Fixed Random Seed (``is not None`` so that seed 0 is honoured as well)
    if params.seed is not None:
        set_random_seed(params.seed)
    # Initialize Experiment
    logger = init_experiment(params, logger_filename=params.logger_filename)
    logger.info(params.__dict__)
    # Dataloader
    if params.task_name == 'IC':
        CL_dataset = Continual_IC_Dataset(dataset=params.dataset,
                                          batch_size=params.batch_size,
                                          class_ft=params.class_ft,
                                          class_pt=params.class_pt,
                                          backbone=params.backbone,
                                          is_mix_er=params.is_mix_er,
                                          seed=params.seed)
    elif params.task_name in ('NER', 'TC'):
        # Continual_NER_Dataset / Continual_TC_Dataset are commented out in this script;
        # fail here with a clear message instead of an UnboundLocalError further down.
        raise NotImplementedError("Task %s is disabled in this script" % params.task_name)
    else:
        raise NotImplementedError()

    # Trainer
    trainer = BaseTrainer(params, CL_dataset)
    task_name = trainer.params.task_name
    num_task = CL_dataset.NUM_TASK

    # Snapshots of the classifier / encoder centers per task (only filled if is_tracking)
    cls_center_dict = {}
    encoder_center_dict = {}

    # Initialize wandb
    if params.is_wandb:
        _init_wandb(params)

    # ============================================================================
    # Start training
    global_step = 0

    # Result Summary Matrix
    #                           Task 0   Task 1   Task 2
    #                        -----------------------------
    #  Learning After Task 0 |        |         |        |
    #                        -----------------------------
    #  Learning After Task 1 |        |         |        |
    #                        -----------------------------
    #  Learning After Task 2 |        |         |        |
    #                        -----------------------------
    if task_name == 'NER':
        result_summary = {
            'micro_f1': -1*np.ones((num_task, num_task)),
            'macro_f1': -1*np.ones((num_task, num_task)),
        }
    elif task_name in ['TC', 'IC']:
        result_summary = {
            'acc': -1*np.ones((num_task, num_task)),
        }
    else:
        raise NotImplementedError()
    # Filled after testing the last task; needed for the forward-transfer metric.
    random_result = None

    for task_id in range(num_task):
        logger.info("============================================================================")
        logger.info("Beggin training the task %d (total %d tasks)"%(task_id, num_task))
        logger.info("============================================================================")

        trainer.begin_task(task_id)

        # Init training variables
        dataset_name = '_'.join(CL_dataset.DATASET_LIST)
        best_model_ckpt_name = "best_finetune_dataset_%s_task_id_%d.pth"%(dataset_name, task_id)
        if task_id == 0 and params.first_training_epochs > 0:
            training_epochs = params.first_training_epochs
        else:
            training_epochs = params.training_epochs

        no_improvement_num = 0
        best_score = -1
        step = 0
        is_finish = False

        # For calculating forward transfer in continual learning:
        # Evaluate the performance on the next task before training on it
        if task_id > 0 and not params.is_debug:
            result_dict = _evaluate_all_seen(trainer, CL_dataset, task_id, 'test')
            if task_name == 'NER':
                result_summary['micro_f1'][task_id-1, task_id] = result_dict['Result_test_mean_mif1']
                result_summary['macro_f1'][task_id-1, task_id] = result_dict['Result_test_mean_maf1']
            else:
                result_summary['acc'][task_id-1, task_id] = result_dict['Result_test_mean_acc']
        # Probing before training
        if params.is_probing:
            _probe_and_compare(params, trainer, CL_dataset, task_id, logger,
                               epoch=0, global_step=global_step, step=0)
        # Tracking the update of cls and encoder
        if params.is_tracking:
            cls_center, encoder_center = tracking_model(trainer, CL_dataset, task_id)
            cls_center_dict[task_id] = [cls_center]
            encoder_center_dict[task_id] = [encoder_center]

        train_loader = _build_train_loader(params, trainer, CL_dataset, task_id)

        for e in range(1, training_epochs+1):
            # Check before begin_epoch so no epoch hook runs for an epoch that is skipped.
            if is_finish:
                break
            trainer.begin_epoch(task_id, e)
            trainer.model.train()
            logger.info("------------------------ epoch %d ------------------------" % e)
            # loss list
            loss_list, distill_list, ce_list = [], [], []

            for idx, X, y in train_loader:

                if is_finish or (params.is_debug and step > 10):
                    break
                X, y = X.cuda(), y.cuda()

                # experience replay
                if task_id > 0 and (global_step+1) % params.replay_interval == 0:
                    if params.is_er:
                        # replay separately: one extra optimisation step on a buffer batch
                        step += 1
                        global_step += 1
                        buffer_idx, buffer_X, buffer_y = _sample_replay_batch(params, trainer, train_loader, task_id)
                        total_loss, ce_loss, distill_loss = trainer.observe_batch(buffer_idx, buffer_X, buffer_y, task_id, e, global_step, is_replay=True)
                        loss_list.append(total_loss)
                        distill_list.append(distill_loss)
                        ce_list.append(ce_loss)
                    elif params.is_mix_er:
                        # combine with data in the new task; replayed rows get index -1
                        buffer_idx, buffer_X, buffer_y = _sample_replay_batch(params, trainer, train_loader, task_id)
                        X, y = combine_two_batch(X, buffer_X, y, buffer_y)
                        idx = _append_replay_idx(idx, buffer_idx)

                # Update the step count
                step += 1
                global_step += 1
                if params.is_buffer and e == 1 and params.sampling_alg == 'reservior':
                    # only update buffer on the first epoch, with the un-transformed images
                    if params.task_name == 'IC':
                        # Skip replayed rows (idx == -1) in BOTH images and labels so they stay aligned.
                        keep_pos = [pos for pos, _idx in enumerate(idx) if _idx != -1]
                        X_origin = [train_loader.dataset.dataset.__getitem__(idx[pos], is_transform=False)[1] for pos in keep_pos]
                        y_origin = y[keep_pos] if len(keep_pos) < len(y) else y
                    else:
                        X_origin, y_origin = X, y
                    trainer.model.buffer.update_buffer_batch(X_origin, y_origin, task_id)
                total_loss, ce_loss, distill_loss = trainer.observe_batch(idx, X, y, task_id, e, global_step)
                if not np.isfinite(total_loss):
                    raise FloatingPointError("Non-finite training loss %s at global step %d" % (total_loss, global_step))
                loss_list.append(total_loss)
                distill_list.append(distill_loss)
                ce_list.append(ce_loss)

                # Print training information
                if params.info_per_steps > 0 and step % params.info_per_steps == 0:
                    mean_loss = _mean_or_zero(loss_list)
                    mean_distill_loss = _mean_or_zero(distill_list)
                    mean_ce_loss = _mean_or_zero(ce_list)
                    logger.info("Epoch %d, Step %d: Total_loss=%.3f, CE_loss=%.3f, Distill_loss=%.3f"%(
                        e, step, mean_loss, mean_ce_loss, mean_distill_loss
                    ))
                    if params.is_wandb:
                        wandb.log({'loss': mean_loss,
                                   'ce_loss': mean_ce_loss,
                                   'distill_loss': mean_distill_loss}, step=global_step)

            trainer.end_epoch(task_id, e)

            # Probing
            if params.is_probing and e % params.probing_interval == 0:
                _probe_and_compare(params, trainer, CL_dataset, task_id, logger,
                                   epoch=e, global_step=global_step, step=step)

            # Tracking the update of cls and encoder
            if params.is_tracking and e % params.tracking_interval == 0:
                cls_center, encoder_center = tracking_model(trainer, CL_dataset, task_id)
                cls_center_dict[task_id].append(cls_center)
                encoder_center_dict[task_id].append(encoder_center)

            # Print training information (the loss lists are reset at the next epoch start)
            if params.info_per_epochs > 0 and e % params.info_per_epochs == 0:
                mean_loss = _mean_or_zero(loss_list)
                mean_distill_loss = _mean_or_zero(distill_list)
                mean_ce_loss = _mean_or_zero(ce_list)
                logger.info("Epoch %d, Step %d: Total_loss=%.3f, CE_loss=%.3f, Distill_loss=%.3f"%(
                    e, step, mean_loss, mean_ce_loss, mean_distill_loss
                ))
                if params.is_wandb:
                    wandb.log({'loss': mean_loss,
                               'ce_loss': mean_ce_loss,
                               'distill_loss': mean_distill_loss}, step=global_step)

            # Save checkpoint
            if params.save_per_epochs > 0 and e % params.save_per_epochs == 0:
                trainer.save_model("checkpoint_dataset_%s_task_id_%d_epoch_%d.pth"%(dataset_name, task_id, e),
                                   path=params.dump_path)
            # For evaluation (model selection + early stopping on the dev split)
            if e % params.evaluate_interval == 0 and (not params.is_debug):
                dev_score = _evaluate_dev(params, trainer, CL_dataset, task_id, logger, e, step, global_step)
                if dev_score > best_score:
                    logger.info("Find better model!!")
                    best_score = dev_score
                    no_improvement_num = 0
                    trainer.save_model(best_model_ckpt_name, path=params.dump_path)
                else:
                    no_improvement_num += 1
                    logger.info("No better model is found (%d/%d)" % (no_improvement_num, params.early_stop))
                if no_improvement_num >= params.early_stop:
                    logger.info("Stop training because no better model is found!!!")
                    is_finish = True

        trainer.end_task(task_id)

        logger.info("Finish training ...")

        # ------------------------------------------------------------------------------------------------------------------------=====
        # testing
        logger.info("Testing...")

        if params.is_debug:
            logger.info('Skip testing in the debug mode')
            continue

        if params.is_use_last_ckpt:
            logger.info('Best model according to dev is not loaded because params.is_use_last_ckpt is True...')
        else:
            trainer.load_model(best_model_ckpt_name, path=params.dump_path)
        trainer.model.cuda()

        result_dict = _evaluate_all_seen(trainer, CL_dataset, task_id, 'test', is_mbpa=params.is_mbpa)
        if task_name == 'NER':
            logger.info("Test result: micro_f1=%.3f, macro_f1=%.3f"%(result_dict['Result_test_mean_mif1'], result_dict['Result_test_mean_maf1']))
        else:
            logger.info("Test result: acc=%.3f"%(result_dict['Result_test_mean_acc']))
        logger.info("Test result dict=%s"%(result_dict))
        logger.info("Finish testing task %d"%(task_id))
        if task_id == num_task-1:
            if task_name == 'NER':
                random_result = compute_random_result_ner(trainer, CL_dataset, task_id, 'test')
            elif task_name == 'TC':
                random_result = compute_random_result_tc(trainer, CL_dataset, task_id, 'test')
            else:
                random_result = compute_random_result_ic(trainer, CL_dataset, task_id, 'test')
        # Fill row ``task_id`` of the result summary with the per-task test scores.
        for t_id in range(task_id+1):
            if task_name == 'NER':
                result_summary['micro_f1'][task_id, t_id] = result_dict['Result_test_mif1_%d'%(t_id)]
                result_summary['macro_f1'][task_id, t_id] = result_dict['Result_test_maf1_%d'%(t_id)]
            else:
                result_summary['acc'][task_id, t_id] = result_dict['Result_test_acc_%d'%(t_id)]
        logger.info('Result Summary aftering finish training task %d = %s'%(task_id, result_summary))

        if params.is_wandb:
            wandb.log(result_dict)

    if params.is_debug:
        logger.info('Skip testing in the debug mode')
        return

    if task_name == 'NER':
        # Compute Forward and Backward Transfer according to Result Summary for the whole learning process
        fwt_mif1 = compute_forward_transfer(result_summary['micro_f1'], random_result['micro_f1'])
        bwt_mif1 = compute_backward_transfer(result_summary['micro_f1'])
        fgt_mif1 = compute_forgetting(result_summary['micro_f1'])
        fwt_maf1 = compute_forward_transfer(result_summary['macro_f1'], random_result['macro_f1'])
        bwt_maf1 = compute_backward_transfer(result_summary['macro_f1'])
        fgt_maf1 = compute_forgetting(result_summary['macro_f1'])

        logger.info('Fwt mif1 = %.2f; Fwt maf1 = %.2f'%(fwt_mif1, fwt_maf1))
        logger.info('Bwt mif1 = %.2f; Bwt maf1 = %.2f'%(bwt_mif1, bwt_maf1))
        logger.info('Fgt mif1 = %.2f; Fgt maf1 = %.2f'%(fgt_mif1, fgt_maf1))

        if params.is_wandb:
            wandb.log({'fwt_mif1': fwt_mif1, 'bwt_mif1': bwt_mif1, 'fgt_mif1': fgt_mif1,
                       'fwt_maf1': fwt_maf1, 'bwt_maf1': bwt_maf1, 'fgt_maf1': fgt_maf1})
            wandb.finish()
    elif task_name in ['TC', 'IC']:
        # Compute Forward and Backward Transfer according to Result Summary for the whole learning process
        fwt_acc = compute_forward_transfer(result_summary['acc'], random_result['acc'])
        bwt_acc = compute_backward_transfer(result_summary['acc'])
        fgt_acc = compute_forgetting(result_summary['acc'])

        logger.info('Fwt acc = %.2f'%(fwt_acc))
        logger.info('Bwt acc = %.2f'%(bwt_acc))
        logger.info('Fgt acc = %.2f'%(fgt_acc))

        if params.is_wandb:
            wandb.log({'fwt_acc': fwt_acc, 'bwt_acc': bwt_acc, 'fgt_acc': fgt_acc})
            wandb.finish()
    else:
        raise NotImplementedError()

    # Save tracking results
    if params.is_tracking:
        save_dict = {'cls_center_dict': cls_center_dict, 'encoder_center_dict': encoder_center_dict}
        with open(os.path.join(params.dump_path, 'tracking_result'), "wb") as f:
            pickle.dump(save_dict, f)

    if not params.save_ckpt:
        # Remove every checkpoint
        for file_name in os.listdir(params.dump_path):
            if file_name.endswith('.pth'):
                os.remove(os.path.join(params.dump_path, file_name))
    else:
        # Only retain the checkpoint of the last task
        last_task_id = num_task - 1
        for file_name in os.listdir(params.dump_path):
            if file_name.endswith('.pth') and not _is_final_task_ckpt(file_name, last_task_id):
                os.remove(os.path.join(params.dump_path, file_name))
        
def set_random_seed(seed):
    """Seed the ``random``, ``numpy``, ``torch`` and ``torch.cuda`` generators with ``seed``.

    Also sets ``torch.backends.cudnn.deterministic`` so that cuDNN picks deterministic kernels.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seed_fn in seeders:
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True

if __name__ == "__main__":
    # Parse the command-line configuration, seed every RNG up front when a seed
    # is given, then run the continual-learning experiment.
    params = get_params()
    seed = params.seed
    if seed is not None:
        set_random_seed(seed=seed)
    main_cl(params)
